# coding:utf-8


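# Goal: use the AMap (Gaode Maps) place-search API to collect the branch
#       information of every bank in each city.
# Approach: 1. City names and codes for the whole country are stored in MySQL.
#           2. Read the city codes into a list.
#           3. For each code in that list, build the request URL and query the
#              API with a GET request.
#           4. Parse the returned XML and insert the fields we need into MySQL.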


import urllib
import xml.dom.minidom as minidom
import string
import urllib.request
import pymysql

# Local files.
file_name = 'result.txt'  # write result to this file
xml_file = 'tmp.xml'  # xml file name: holds the most recently downloaded API response

# AMap place-text search request pieces; a request URL is built as
# url_header + 'city=<code>' + url_end.
url_header = 'http://restapi.amap.com/v3/place/text?&keyword=&types=160100&'
# NOTE(review): the API key is committed in source; consider loading it from config.
url_end = ('&citylimit=true&&output=xml&offset=20&page=1'
           '&key=c787ae8e49424a657127c3ed64cfe053&extensions=base')
url_amap = 'city='

# Results per page; must agree with the offset=20 query parameter in url_end.
each_page_rec = 20  # results that displays in one page


# get html by url and save the data to xml file
def gethtml(url):
    """Download ``url`` and save the raw response body to ``xml_file``.

    Args:
        url: Request URL, fetched with a plain GET via ``urllib.request.urlopen``.

    Returns:
        0 on success; -1 if either the request or the file write failed with an
        ``OSError``/``IOError`` (``urllib.error.URLError`` is a subclass of it).
    """
    try:
        # The request is inside the try so network failures take the -1 path
        # the caller checks for, and the response is closed deterministically.
        with urllib.request.urlopen(url) as page:
            html = page.read()
        # Binary mode: store the bytes exactly as received.
        with open(xml_file, 'wb') as xml_file_handle:
            xml_file_handle.write(html)
    except IOError as err:
        print("IO error: " + str(err))
        return -1
    return 0


# parse data from the saved xml file
def parsexml():
    """Parse the saved AMap XML response and insert each POI into ``bankinfo``.

    Reads ``xml_file`` and, for every ``<poi>`` under ``<response><pois>``,
    inserts one row into the ``bankinfo`` table.

    Returns:
        The text of the ``<count>`` element (a string, converted to an int by the
        caller), or ``1`` when the XML file cannot be opened or carries no
        non-empty ``<count>``.
    """
    total_rec = 1  # record number; fallback when no count is available

    try:
        dom = minidom.parse(xml_file)
    except IOError as err:
        print("IO error: " + str(err))
        return total_rec

    rows = []
    # getElementsByTagName returns a NodeList (possibly empty).
    for node in dom.getElementsByTagName("response"):
        count = _first_text(node, 'count', default=None)
        if count is not None:
            total_rec = count

        pois = node.getElementsByTagName("pois")
        if not pois:
            continue
        for poi in pois[0].getElementsByTagName('poi'):
            rows.append(_poi_row(poi))

    if rows:
        _insert_rows(rows)
    return total_rec


def _first_text(parent, tag, default=''):
    """Return the text of the first ``tag`` element below ``parent``, or ``default``.

    Empty elements such as ``<adname/>`` have no child nodes; they (and missing
    elements) yield ``default`` instead of raising IndexError.
    """
    elements = parent.getElementsByTagName(tag)
    if not elements or not elements[0].childNodes:
        return default
    return elements[0].childNodes[0].nodeValue


def _poi_row(poi):
    """Build the ``bankinfo`` column values for one ``<poi>`` element."""
    branch_type = _first_text(poi, "type")
    return (
        _first_text(poi, "id"),                                        # branch_id
        _first_text(poi, "name").replace('(', '').replace(')', ''),   # branch_name
        branch_type,                                                   # branch_type
        branch_type.split(';')[-1],  # bank_name: last ';'-separated segment of type
        _first_text(poi, "typecode"),                                  # bank_type
        _first_text(poi, "pname"),                                     # pname
        _first_text(poi, "cityname"),                                  # cityname
        _first_text(poi, "adname"),                                    # aname
    )


def _insert_rows(rows):
    """Insert ``rows`` into ``bankinfo`` over a single connection, committing each row."""
    # Parameterized query: values containing quotes can no longer break or
    # inject into the SQL statement.
    sql = ("insert into bankinfo(branch_id, branch_name, branch_type, bank_name, "
           "bank_type, pname, cityname, aname) "
           "values(%s, %s, %s, %s, %s, %s, %s, %s)")
    connection = pymysql.connect(host='192.168.11.23', user='root', passwd='123456', port=30306,
                                 db='test', charset="utf8")
    try:
        with connection.cursor() as cursor:
            for row in rows:
                statement = cursor.mogrify(sql, row)
                print(statement)
                cursor.execute(sql, row)
                # Commit per row so already-inserted rows persist if a later one fails.
                connection.commit()
                if cursor.rowcount != 1:
                    raise Exception("数据插入失败%s" % (statement,))
    finally:
        connection.close()


def frange(start, stop, step=1):
    """Lazily yield start, start + step, ... for every value strictly below ``stop``.

    Like ``range`` but also accepts float bounds/steps; the value is advanced by
    repeated addition.
    """
    current, limit, increment = start, stop, step
    while current < limit:
        yield current
        current = current + increment


def getallcity():
    """Return the first column of every second-level row of the ``region`` table.

    Selects rows whose ``parent_id`` is the id of a top-level region
    (``parent_id = 0``). The first column of each row is returned in a list; the
    caller uses it as the AMap ``city`` parameter (presumably the region
    id/code — verify against the table schema, since the query uses ``select *``).
    """
    sql = "select * from region where parent_id in (select id from region where parent_id=0)"
    connection = pymysql.connect(host='192.168.11.23', user='root', passwd='123456', port=30306,
                                 db='test', charset="utf8")
    # The try begins right after connect so the connection is closed even if
    # creating the cursor fails; the cursor closes itself via `with`.
    try:
        with connection.cursor() as cursor:
            cursor.execute(sql)
            return [row[0] for row in cursor.fetchall()]
    finally:
        connection.close()


if __name__ == '__main__':
    cityarr = getallcity()
    for cityId in cityarr:
        url = r'%scity=%s%s' % (url_header, cityId, url_end)
        if gethtml(url) != 0:
            print('error: fail to get xml from amap')
            continue

        # Page 1 is already downloaded; parsexml stores it and reports the total count.
        total_record = int(str(parsexml()))
        # Pages needed for total_record results (integer ceiling division, so no
        # extra empty page is requested when the count is not a multiple of 20).
        page_count = (total_record + each_page_rec - 1) // each_page_rec

        for each_page in range(2, page_count + 1):
            print('parsing page ' + str(each_page) + ' ... ...')
            url = url.replace('page=' + str(each_page - 1), 'page=' + str(each_page))
            print(url)
            # Parse only after a successful download; otherwise xml_file still holds
            # the previous page and re-parsing it would insert duplicate rows.
            if gethtml(url) == 0:
                parsexml()
            else:
                print('error: fail to get xml from amap')